{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5",
   "metadata": {},
   "source": [
    "# Week 2 exercise\n",
    "\n",
    "## MathXpert with tools integration\n",
    "\n",
    "- Lets you explore every model offered by the configured providers\n",
    "- Handles multiple tool calls requested in a single response\n",
    "- Runs those tool calls efficiently in parallel\n",
    "- Returns some tool results, e.g. the `plot_function` image, directly, without another round-trip to the LLM\n",
    "- Uses Python's built-in `logging` package to control verbosity; raise the `mathxpert` logger's level (e.g. to INFO) to reduce the noisy output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1070317-3ed9-4659-abe3-828943230e03",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "import textwrap\n",
    "from enum import StrEnum\n",
    "from getpass import getpass\n",
    "from types import SimpleNamespace\n",
    "from typing import Any, Callable\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output, Latex\n",
    "import gradio as gr\n",
    "\n",
    "# Load API keys (and any other settings) from a local .env file into os.environ;\n",
    "# override=True lets values from .env replace variables already set in the environment.\n",
    "load_dotenv(override=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99901b80",
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.basicConfig(level=logging.WARNING)\n",
    "\n",
    "logger = logging.getLogger('mathxpert')\n",
    "logger.setLevel(logging.DEBUG)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f169118a-645e-44e1-9a98-4f561adfbb08",
   "metadata": {},
   "source": [
    "## Free Cloud Providers\n",
    "\n",
    "Grab your free API Keys from these generous sites:\n",
    "\n",
    "- https://openrouter.ai/\n",
    "- https://ollama.com/\n",
    "\n",
    ">**NOTE**: If you do not have a key for a provider, simply press ENTER to skip it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a456906-915a-4bfd-bb9d-57e505c5093f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Provider(StrEnum):\n",
    "    '''OpenAI-compatible cloud providers this notebook can talk to.'''\n",
    "    OLLAMA = 'Ollama'\n",
    "    OPENROUTER = 'OpenRouter'\n",
    "\n",
    "\n",
    "# One client per provider whose API key was supplied (filled in further down this cell).\n",
    "clients: dict[Provider, OpenAI] = {}\n",
    "\n",
    "# Model ids per provider, fetched lazily; every provider starts with an empty list.\n",
    "models: dict[Provider, list[str]] = {provider: [] for provider in Provider}\n",
    "\n",
    "DEFAULT_PROVIDER = Provider.OLLAMA\n",
    "\n",
    "# Preferred model per provider, used when the provider offers it.\n",
    "selection_state: dict[Provider, str | None] = {\n",
    "    Provider.OLLAMA: 'gpt-oss:20b',\n",
    "    Provider.OPENROUTER: 'openai/gpt-oss-20b:free',\n",
    "}\n",
    "\n",
    "def get_secret_in_google_colab(env_name: str) -> str:\n",
    "    try:\n",
    "      from google.colab import userdata\n",
    "      return userdata.get(env_name)\n",
    "    except Exception:\n",
    "      return ''\n",
    "      \n",
    "\n",
    "def get_secret(env_name: str) -> str:\n",
    "    '''Return the secret `env_name`, stripped of surrounding whitespace.\n",
    "\n",
    "    Looks in the process environment first, then in Google Colab's secret store,\n",
    "    and finally prompts the user (hidden input). Returns '' when nothing was given.\n",
    "    '''\n",
    "    key = os.environ.get(env_name) or get_secret_in_google_colab(env_name)\n",
    "    if not key:\n",
    "        key = getpass(f'Enter {env_name}:').strip()\n",
    "\n",
    "    if not key:\n",
    "        logger.warning(f'❌ {env_name} not provided')\n",
    "        return key\n",
    "\n",
    "    logger.info(f'✅ {env_name} provided')\n",
    "    return key.strip()\n",
    "\n",
    "\n",
    "# Create a client for every provider whose API key is available (asked for in this order).\n",
    "for provider, env_name, base_url in (\n",
    "    (Provider.OLLAMA, 'OLLAMA_API_KEY', 'https://ollama.com/v1'),\n",
    "    (Provider.OPENROUTER, 'OPENROUTER_API_KEY', 'https://openrouter.ai/api/v1'),\n",
    "):\n",
    "    if api_key := get_secret(env_name):\n",
    "        clients[provider] = OpenAI(api_key=api_key, base_url=base_url)\n",
    "\n",
    "available_providers = [str(p) for p in clients]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aae1579b-7a02-459d-81c6-0f775d2a1410",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Active selection; resolved right below from DEFAULT_PROVIDER and the configured clients.\n",
    "selected_provider, selected_model, client = '', '', None\n",
    "\n",
    "\n",
    "def get_desired_value_or_first_item(desire, options) -> str | None:\n",
    "    '''Return `desire` if it is a truthy member of `options`, else the first option (None if empty).'''\n",
    "    logger.debug(f'Pick {desire} from {options}')\n",
    "    if desire in options and desire:\n",
    "        return desire\n",
    "    return next(iter(options), None)\n",
    "        \n",
    "selected_provider = get_desired_value_or_first_item(DEFAULT_PROVIDER, available_providers)\n",
    "client = clients.get(selected_provider)\n",
    "if client is None:\n",
    "    # The helper returns None (it never raises) when no provider is configured, so this\n",
    "    # has to be an explicit check rather than a try/except.\n",
    "    logger.warning('❌ no provider configured and everything else from here will FAIL 🤦, I know you know this already.')\n",
    "\n",
    "\n",
    "def load_models_if_needed(client: OpenAI, selected_provider):\n",
    "    '''Fetch and cache the model ids of `selected_provider`, then sync the selected model.\n",
    "\n",
    "    Does nothing when `client` is falsy. Otherwise:\n",
    "      * `models[selected_provider]` is filled from `client.models.list()` when still empty;\n",
    "      * the global `selected_model` is re-resolved on *every* call, so switching back to an\n",
    "        already-loaded provider restores that provider's model instead of keeping the old one;\n",
    "      * `selection_state[selected_provider]` is set to the resolved model (when there is one),\n",
    "        so a preferred model the provider does not offer is replaced by one that it does.\n",
    "    '''\n",
    "    global selected_model\n",
    "\n",
    "    if not client:\n",
    "        return\n",
    "\n",
    "    if not models.get(selected_provider):\n",
    "        logger.info(f'📡 Fetching {selected_provider} models...')\n",
    "        models[selected_provider] = [model.id for model in client.models.list()]\n",
    "\n",
    "    selected_model = get_desired_value_or_first_item(\n",
    "        selection_state.get(selected_provider),\n",
    "        models[selected_provider],\n",
    "    )\n",
    "    if selected_model is not None:\n",
    "        selection_state[selected_provider] = selected_model\n",
    "\n",
    "\n",
    "load_models_if_needed(client, selected_provider)\n",
    "\n",
    "logger.info(f'ℹ️ Provider: {selected_provider} Model: {selected_model}, Client: {client}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e04675c2-1b81-4187-868c-c7112cd77e37",
   "metadata": {},
   "source": [
    "## Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d7923c-5f28-4c30-8556-342d7c8497c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_messages(question: str) -> list[dict[str, str]]:\n",
    "    \"\"\"Generate messages for the chat models.\"\"\"\n",
    "\n",
    "    system_prompt = r'''\n",
    "    You are MathXpert, an expert Mathematician who makes math fun to learn by relating concepts to real \n",
    "    practical usage to whip up the interest in learners.\n",
    "    \n",
    "    Explain step-by-step thoroughly how to solve a math problem. \n",
    "    - ALWAYS use `$$...$$` for mathematical expressions.\n",
    "    - NEVER use square brackets `[...]` to delimit math.\n",
    "    - Example: Instead of \"[x = 2]\", write \"$$x = 2$$\".\n",
    "    - You may use `\\\\[4pt]` inside matrices for spacing.\n",
    "    '''\n",
    "\n",
    "    return [\n",
    "        {'role': 'system', 'content': system_prompt },\n",
    "        {'role': 'user', 'content': question},\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caa51866-f433-4b9a-ab20-fff5fc3b7d63",
   "metadata": {},
   "source": [
    "## Tools"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a24c659a-5937-43b1-bb95-c0342f2786a9",
   "metadata": {},
   "source": [
    "### Tools Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f302f47-9a67-4410-ba16-56fa5a731c66",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import BaseModel, Field\n",
    "from openai.types.shared_params import FunctionDefinition\n",
    "import sympy as sp\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import io\n",
    "import base64\n",
    "import random\n",
    "\n",
    "class ToolInput(BaseModel):\n",
    "    '''Base for tool argument models; `ToolManager.register_tool` publishes a subclass's\n",
    "    JSON schema as the tool's `parameters`.'''\n",
    "    \n",
    "class GetCurrentDateTimeInput(ToolInput):\n",
    "    # Arguments of `get_current_datetime`. No class docstring on purpose: pydantic copies\n",
    "    # a model docstring into the JSON schema `description` that is sent to the LLM.\n",
    "    timezone: str = Field(\n",
    "        default=\"UTC\",\n",
    "        description=\"Timezone name, e.g., 'UTC' or 'Africa/Accra'\",\n",
    "    )\n",
    "\n",
    "\n",
    "def get_current_datetime(req: GetCurrentDateTimeInput):\n",
    "    '''Returns the current date and time in the specified timezone.'''\n",
    "    # NOTE: the docstring above is sent to the LLM as the tool description.\n",
    "    # Returns {'date': 'YYYY-MM-DD', 'time': 'HH:MM:SS <zone abbreviation>'} or {'error': ...}.\n",
    "    from datetime import datetime\n",
    "    from zoneinfo import ZoneInfo, ZoneInfoNotFoundError\n",
    "\n",
    "    try:\n",
    "        tz = ZoneInfo(req.timezone)\n",
    "    except (ZoneInfoNotFoundError, ValueError, OSError):\n",
    "        # Only an unknown/malformed timezone name is reported back to the model;\n",
    "        # anything else (including KeyboardInterrupt) is allowed to surface.\n",
    "        return {\"error\": f\"Invalid timezone: {req.timezone}\"}\n",
    "\n",
    "    now = datetime.now(tz)\n",
    "    return {\n",
    "        \"date\": now.strftime(\"%Y-%m-%d\"),\n",
    "        \"time\": now.strftime(\"%H:%M:%S %Z\"),\n",
    "    }\n",
    "\n",
    "\n",
    "class GetTemperatureInput(ToolInput):\n",
    "    # `get_temperature` takes no arguments, so this model declares no fields.\n",
    "    ...\n",
    "\n",
    "def get_temperature(req: GetTemperatureInput) -> float:\n",
    "    '''Returns the current temperature in degree celsius'''\n",
    "    # NOTE: the docstring above is sent to the LLM as the tool description.\n",
    "    # Simulated reading (no real sensor/API behind it): a whole number of degrees in\n",
    "    # [-30, 70], returned as a float so the value matches the annotated return type.\n",
    "    return float(random.randint(-30, 70))\n",
    "\n",
    "\n",
    "class PlotFunctionInput(ToolInput):\n",
    "    # Arguments of `plot_function`; the Field descriptions end up in the schema shown to the LLM.\n",
    "    expression: str = Field(\n",
    "        description=\"Mathematical expression to plot, e.g., 'sin(x)'\",\n",
    "    )\n",
    "    x_min: float = Field(default=-10, description=\"Minimum x value\")\n",
    "    x_max: float = Field(default=10, description=\"Maximum x value\")\n",
    "\n",
    "\n",
    "def plot_function(req: PlotFunctionInput) -> dict[str, Any]:\n",
    "    '''Plots a mathematical function and returns image data.'''\n",
    "    # NOTE: the docstring above is sent to the LLM as the tool description.\n",
    "    #\n",
    "    # Returns {'plot_image': <PNG data URI>, 'expression': ..., 'x_range': [x_min, x_max]}\n",
    "    # or {'error': ...}. A matplotlib `Figure` is used directly instead of pyplot because\n",
    "    # tools run on a ThreadPoolExecutor and pyplot's global current-figure state is not\n",
    "    # thread-safe; an unregistered Figure also cannot leak as an open figure on failure.\n",
    "    from matplotlib.figure import Figure\n",
    "\n",
    "    try:\n",
    "        x = sp.symbols('x')\n",
    "        # NOTE(review): sympify/parse_expr use eval() and the expression comes from the\n",
    "        # LLM, so treat it as untrusted input.\n",
    "        try:\n",
    "            expr = sp.sympify(req.expression)\n",
    "        except sp.SympifyError:\n",
    "            # Retry with implicit multiplication so inputs such as '3x' or '2 sin(x)' work.\n",
    "            from sympy.parsing.sympy_parser import (\n",
    "                convert_xor,\n",
    "                implicit_multiplication_application,\n",
    "                parse_expr,\n",
    "                standard_transformations,\n",
    "            )\n",
    "            expr = parse_expr(\n",
    "                req.expression,\n",
    "                transformations=standard_transformations + (implicit_multiplication_application, convert_xor),\n",
    "            )\n",
    "        f = sp.lambdify(x, expr, 'numpy')\n",
    "\n",
    "        x_vals = np.linspace(req.x_min, req.x_max, 400)\n",
    "        # A constant expression (e.g. '5') evaluates to a scalar; broadcast it to x's shape.\n",
    "        y_vals = np.broadcast_to(f(x_vals), x_vals.shape)\n",
    "\n",
    "        fig = Figure(figsize=(10, 6))\n",
    "        ax = fig.subplots()\n",
    "        ax.plot(x_vals, y_vals, 'b-', linewidth=2)\n",
    "        ax.grid(True, alpha=0.3)\n",
    "        ax.set_title(f\"Plot of ${sp.latex(expr)}$\", fontsize=14)\n",
    "        ax.set_xlabel('x', fontsize=12)\n",
    "        ax.set_ylabel('f(x)', fontsize=12)\n",
    "\n",
    "        buf = io.BytesIO()\n",
    "        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')\n",
    "        img_str = base64.b64encode(buf.getvalue()).decode()\n",
    "\n",
    "        return {\n",
    "            \"plot_image\": f\"data:image/png;base64,{img_str}\",\n",
    "            \"expression\": req.expression,\n",
    "            \"x_range\": [req.x_min, req.x_max]\n",
    "        }\n",
    "    except Exception as e:\n",
    "        return {\"error\": f\"Could not plot function: {str(e)}\"}\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fae3ef71-f6cd-4894-ae55-9f4f8dd2a1cd",
   "metadata": {},
   "source": [
    "### Tools registration & execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f18bc9f-f8d1-4208-a3d7-e4e911034572",
   "metadata": {},
   "outputs": [],
   "source": [
    "from concurrent.futures import ThreadPoolExecutor\n",
    "\n",
    "class ToolManager:\n",
    "    '''Registry of LLM-callable tools plus a parallel executor for the calls the LLM requests.\n",
    "\n",
    "    A tool is a function taking one `ToolInput` instance (or nothing). Its `__doc__`\n",
    "    becomes the tool description and the input model's JSON schema its parameters.\n",
    "    '''\n",
    "\n",
    "    def __init__(self):\n",
    "        self._tools: list[dict] = []  # payload passed as `tools=` to the chat API\n",
    "        self._tools_map: dict[str, tuple[Callable, type[ToolInput] | None]] = {}  # name -> (fn, input model)\n",
    "\n",
    "    def register_tool[T: ToolInput](self, fn: Callable, fn_input: type[T] | None):\n",
    "        '''Expose `fn` to the LLM under `fn.__name__`.\n",
    "\n",
    "        Args:\n",
    "            fn: tool implementation, called as `fn(fn_input(**args))`, or `fn()` without an input model.\n",
    "            fn_input: pydantic model *class* describing the arguments, or None.\n",
    "        '''\n",
    "        self._tools.append({\n",
    "            \"type\": \"function\",\n",
    "            \"function\": FunctionDefinition(\n",
    "                name=fn.__name__,\n",
    "                description=fn.__doc__,\n",
    "                parameters=fn_input.model_json_schema() if fn_input else None,\n",
    "            )\n",
    "        })\n",
    "\n",
    "        self._tools_map[fn.__name__] = (fn, fn_input)\n",
    "\n",
    "    def _run_single_tool(self, tool_call) -> dict[str, str] | None:\n",
    "        '''Execute one accumulated tool call and wrap the outcome as a `tool` message.\n",
    "\n",
    "        Returns None only when the call has no id, since there is nothing to answer then.\n",
    "        An unknown tool name, malformed JSON arguments or an exception raised by the tool\n",
    "        produce a `{\"error\": ...}` payload instead of dropping the message: chat APIs\n",
    "        generally reject a conversation in which an assistant tool call has no matching\n",
    "        tool message, and the error text lets the model recover.\n",
    "        '''\n",
    "        if not tool_call.id:\n",
    "            return None\n",
    "\n",
    "        name = tool_call.function.name\n",
    "        entry = self._tools_map.get(name)\n",
    "        try:\n",
    "            if entry is None:\n",
    "                raise LookupError(f'unknown tool {name!r}')\n",
    "            fn, fn_input = entry\n",
    "\n",
    "            if fn_input is None:\n",
    "                result = fn()\n",
    "            else:\n",
    "                args = tool_call.function.arguments\n",
    "                result = fn(fn_input(**(json.loads(args) if args else {})))\n",
    "\n",
    "            logger.debug(f'Tool run result: {result}')\n",
    "            content = json.dumps(result)\n",
    "        except Exception as e:\n",
    "            # The tool name goes in the message, not `extra={'name': ...}`: `name` is a\n",
    "            # reserved LogRecord attribute and logging raises KeyError if it is overwritten.\n",
    "            logger.error(f'Tool execution failed ({name}): {e}')\n",
    "            content = json.dumps({'error': f'Tool {name} failed: {e}'})\n",
    "\n",
    "        return {\n",
    "            'role': 'tool',\n",
    "            'tool_call_id': tool_call.id,\n",
    "            'content': content,\n",
    "        }\n",
    "\n",
    "    def run(self, tool_calls) -> list[dict[str, str]]:\n",
    "        '''Run `tool_calls` concurrently and return their tool messages in call order.'''\n",
    "        if not tool_calls:\n",
    "            return []\n",
    "\n",
    "        logger.debug(tool_calls)\n",
    "\n",
    "        with ThreadPoolExecutor() as executor:\n",
    "            messages = list(executor.map(self._run_single_tool, tool_calls))\n",
    "\n",
    "        return [message for message in messages if message]\n",
    "\n",
    "    @property\n",
    "    def tools(self) -> list[dict]:\n",
    "        '''Tool definitions in the OpenAI `tools=` format.'''\n",
    "        return self._tools\n",
    "\n",
    "    def dump_tools(self) -> str:\n",
    "        '''Pretty-printed JSON of the registered tool definitions (for debugging).'''\n",
    "        return json.dumps(self._tools, indent=2)\n",
    "\n",
    "\n",
    "tool_manager = ToolManager()\n",
    "\n",
    "tool_manager.register_tool(get_current_datetime, GetCurrentDateTimeInput)\n",
    "tool_manager.register_tool(get_temperature, GetTemperatureInput)\n",
    "tool_manager.register_tool(plot_function, PlotFunctionInput)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b2e0634-de5d-45f6-a8d4-569e04d14a00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the tool definitions (JSON schemas) that `ask` sends to the model; shown at DEBUG level.\n",
    "logger.debug(tool_manager.dump_tools())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bde4cd2a-b681-4b78-917c-d970c264b151",
   "metadata": {},
   "source": [
    "## Interaction with LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f7c8ea8-4082-4ad0-8751-3301adcf6538",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _extract_plot_image(tool_message: dict[str, str]) -> str | None:\n",
    "    '''Return the PNG data URI carried by a `plot_function` tool message, or None.'''\n",
    "    try:\n",
    "        data = json.loads(tool_message['content'])\n",
    "    except (KeyError, TypeError, json.JSONDecodeError):\n",
    "        return None\n",
    "    # Other tools may return non-dict JSON (e.g. a bare number), hence the isinstance check.\n",
    "    return data.get('plot_image') if isinstance(data, dict) else None\n",
    "\n",
    "\n",
    "def ask(client: OpenAI | None, model: str, question: str, max_tool_turns=5):\n",
    "    '''Stream MathXpert's answer to `question`, running any tools the model asks for.\n",
    "\n",
    "    Generator: each yielded string is the complete answer-so-far of the current LLM\n",
    "    turn, which suits Gradio (it displays the latest yielded value).\n",
    "\n",
    "    Args:\n",
    "        client: OpenAI-compatible client; when None a hint is yielded instead.\n",
    "        model: model id understood by `client`.\n",
    "        question: the user's question; no conversation history is kept.\n",
    "        max_tool_turns: maximum number of LLM round-trips spent resolving tool calls.\n",
    "\n",
    "    Yields:\n",
    "        Partial answer text; an HTML <img> tag when `plot_function` produced a plot\n",
    "        (this ends the interaction without going back to the LLM); or an error message.\n",
    "    '''\n",
    "    if client is None:\n",
    "        hint = 'You should have provided the API Keys you know. Fix 🔧 this and try again ♻️.'\n",
    "        logger.warning(hint)\n",
    "        yield hint\n",
    "        return\n",
    "\n",
    "    try:\n",
    "        logger.debug(f'# Tools: {len(tool_manager.tools)}')\n",
    "\n",
    "        messages = get_messages(question=question)\n",
    "\n",
    "        for turn in range(max_tool_turns):\n",
    "            logger.debug(f'Turn: {turn}')\n",
    "            response = client.chat.completions.create(\n",
    "                model=model,\n",
    "                messages=messages,\n",
    "                tools=tool_manager.tools,\n",
    "                stream=True,\n",
    "            )\n",
    "\n",
    "            output = ''\n",
    "            finish_reason = None\n",
    "            call_id = None\n",
    "            tool_calls_accumulator = {}\n",
    "\n",
    "            for chunk in response:\n",
    "                # Some providers emit chunks without choices (e.g. a trailing usage chunk).\n",
    "                if not chunk.choices:\n",
    "                    continue\n",
    "                choice = chunk.choices[0]\n",
    "                logger.debug(f' ✨  {choice}')\n",
    "                finish_reason = choice.finish_reason or finish_reason\n",
    "                delta = choice.delta\n",
    "\n",
    "                if content := delta.content:\n",
    "                    output += content\n",
    "                    yield output\n",
    "\n",
    "                for tool_chunk in delta.tool_calls or []:\n",
    "                    logger.debug(f'Tool call chunk: {tool_chunk}')\n",
    "                    # Only the first chunk of a call carries its id; the argument fragments\n",
    "                    # that follow belong to the most recently announced call.\n",
    "                    if tool_chunk.id:\n",
    "                        call_id = tool_chunk.id\n",
    "\n",
    "                    if call_id not in tool_calls_accumulator:\n",
    "                        tool_calls_accumulator[call_id] = SimpleNamespace(\n",
    "                            id=call_id,\n",
    "                            function=SimpleNamespace(name='', arguments=''),\n",
    "                        )\n",
    "\n",
    "                    if tool_chunk.function:\n",
    "                        pending = tool_calls_accumulator[call_id].function\n",
    "                        pending.name += tool_chunk.function.name or ''\n",
    "                        pending.arguments += tool_chunk.function.arguments or ''\n",
    "\n",
    "            logger.debug(f'🧠 LLM interaction ended. Reason: {finish_reason}')\n",
    "\n",
    "            # A call whose id never arrived cannot be answered with a tool message.\n",
    "            final_tool_calls = [call for call in tool_calls_accumulator.values() if call.id]\n",
    "            if not final_tool_calls:\n",
    "                return\n",
    "\n",
    "            logger.debug(f'Final tools to call {final_tool_calls}')\n",
    "            messages.append({\n",
    "                'role': 'assistant',\n",
    "                'content': output or None,  # keep any text streamed before the tool calls\n",
    "                'tool_calls': [\n",
    "                    {\n",
    "                        'id': call.id,\n",
    "                        'type': 'function',  # required by the OpenAI chat message format\n",
    "                        'function': {\n",
    "                            'name': call.function.name,\n",
    "                            'arguments': call.function.arguments or '{}',\n",
    "                        },\n",
    "                    }\n",
    "                    for call in final_tool_calls\n",
    "                ],\n",
    "            })\n",
    "\n",
    "            tool_messages = tool_manager.run(final_tool_calls)\n",
    "\n",
    "            for tool_msg in tool_messages:\n",
    "                if plot_image := _extract_plot_image(tool_msg):\n",
    "                    logger.debug('We have a plot')\n",
    "                    # The plot is the answer: show it without another LLM round-trip.\n",
    "                    yield f'<img src=\"{plot_image}\" style=\"max-width: 100%; height: auto; border: 1px solid #ccc; border-radius: 5px;\">'\n",
    "                    return\n",
    "\n",
    "            messages.extend(tool_messages)\n",
    "\n",
    "        logger.warning(f'Gave up after {max_tool_turns} tool turns without a final answer')\n",
    "    except Exception as e:\n",
    "        logger.error(f'🔥 An error occurred during the interaction with the LLM: {e}', exc_info=True)\n",
    "        # A generator's `return value` is invisible to for-loops and Gradio; yield it instead.\n",
    "        yield f'🔥 An error occurred during the interaction with the LLM: {e}'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda786d3-5add-4bd1-804d-13eff60c3d1a",
   "metadata": {},
   "source": [
    "### Verify streaming behaviour"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09bc9a11-adb4-4a9c-9c77-73b2b5a665cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test `ask` outside the UI: every yielded value is printed, so you can watch the\n",
    "# answer grow chunk by chunk (or see the <img> tag returned for a plot).\n",
    "# Other questions worth trying (each should make the model call a tool):\n",
    "#   'What is the time?'                      -> get_current_datetime\n",
    "#   'What is the temperature?'               -> get_temperature\n",
    "#   'What is the time and the temperature?'  -> both tools (run in parallel when requested together)\n",
    "#   'Plot a graph of sin(x)'                 -> plot_function\n",
    "for o in ask(client, selected_model, 'Plot a graph of y = x**2'):\n",
    "    print(o)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27230463",
   "metadata": {},
   "source": [
    "## Build Gradio UI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50fc3577",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(message: str, history: list[dict], selected_provider: str, model_selector: str):\n",
    "    '''Gradio `ChatInterface` callback: stream MathXpert's answer to `message`.\n",
    "\n",
    "    Args:\n",
    "        message: the user's question.\n",
    "        history: ignored; I'm not interested in maintaining a conversation.\n",
    "        selected_provider: current value of the Provider dropdown (additional input).\n",
    "        model_selector: current value of the Model dropdown (additional input).\n",
    "    '''\n",
    "    # Use what the dropdowns show rather than the module globals; fall back to the\n",
    "    # globals when a dropdown is empty.\n",
    "    provider_client = clients.get(selected_provider) if selected_provider else client\n",
    "    model = model_selector or selected_model\n",
    "\n",
    "    yield from ask(provider_client, model, message)\n",
    "\n",
    "def on_provider_change(change):\n",
    "    '''Switch the active provider and refresh the Model dropdown.\n",
    "\n",
    "    Args:\n",
    "        change: provider name newly selected in the Provider dropdown.\n",
    "\n",
    "    Returns:\n",
    "        A `gr.Dropdown` listing the provider's models, whose value is always one of those\n",
    "        choices: the remembered model when offered, else the first one, else None.\n",
    "    '''\n",
    "    global selected_provider, selected_model, client\n",
    "    logger.info(f'Provider changed to {change}')\n",
    "    selected_provider = change\n",
    "    client = clients.get(selected_provider)\n",
    "    load_models_if_needed(client, selected_provider)\n",
    "\n",
    "    choices = models.get(selected_provider, [])\n",
    "    # Re-resolve the model here: the remembered one may not be offered, and the global\n",
    "    # must follow the provider even when its models were already cached.\n",
    "    selected_model = get_desired_value_or_first_item(selection_state.get(selected_provider), choices)\n",
    "    if selected_model is not None:\n",
    "        selection_state[selected_provider] = selected_model\n",
    "\n",
    "    return gr.Dropdown(\n",
    "        choices=choices,\n",
    "        value=selected_model,\n",
    "        interactive=True,\n",
    "    )\n",
    "\n",
    "\n",
    "def on_model_change(change):\n",
    "    '''Remember the model chosen in the Model dropdown for the active provider.'''\n",
    "    global selected_model\n",
    "\n",
    "    selection_state[selected_provider] = change\n",
    "    selected_model = change\n",
    "    logger.info(f'👉 Selected model: {selected_model}')\n",
    "\n",
    "\n",
    "with gr.Blocks(title='MathXpert', fill_width=True) as ui:\n",
    "    # .get() keeps the UI buildable when no provider is configured (selected_provider is None).\n",
    "    initial_models = models.get(selected_provider, [])\n",
    "\n",
    "    with gr.Row():\n",
    "        provider_selector = gr.Dropdown(\n",
    "            choices=available_providers,\n",
    "            value=get_desired_value_or_first_item(selected_provider, available_providers),\n",
    "            label='Provider',\n",
    "        )\n",
    "        model_selector = gr.Dropdown(\n",
    "            choices=initial_models,\n",
    "            value=get_desired_value_or_first_item(selection_state.get(selected_provider), initial_models),\n",
    "            label='Model',\n",
    "        )\n",
    "\n",
    "    provider_selector.change(fn=on_provider_change, inputs=provider_selector, outputs=model_selector)\n",
    "    model_selector.change(fn=on_model_change, inputs=model_selector)\n",
    "\n",
    "    # Each example is [message, provider, model]; None means the example sets no provider/model.\n",
    "    examples = [\n",
    "        ['Where can substitutions be applied in real life?', None, None],\n",
    "        ['Give 1 differential equation question and solve it', None, None],\n",
    "        ['Plot x**2 - 3x', None, None],\n",
    "        ['What is the time now?', None, None],\n",
    "        ['What is the temperature?', None, None],\n",
    "        ['Tell me the time and the temperature now', None, None],\n",
    "    ]\n",
    "\n",
    "    gr.ChatInterface(\n",
    "        fn=chat,\n",
    "        type='messages',\n",
    "        chatbot=gr.Chatbot(type='messages', height='75vh', resizable=True),\n",
    "        additional_inputs=[provider_selector, model_selector],\n",
    "        examples=examples,\n",
    "    )\n",
    "\n",
    "ui.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
